#include "bitwise_left_shift_tiling.h"
#include "register/op_def_registry.h"
#include "tiling/platform/platform_ascendc.h"
#include "graph/utils/type_utils.h"
#include <algorithm>
#include <vector>
// Granularity, in bytes, used by the tiling arithmetic: input length and UB partitions are counted in these blocks.
constexpr uint64_t BLOCK_SIZE{256};
// Divisor applied to each UB partition length (presumably the kernel's per-queue buffer depth — TODO confirm).
constexpr uint64_t BUFFER_NUM{1};
namespace optiling {
namespace {
// Largest value that fits the 32-bit tiling fields / block dim written below.
constexpr uint64_t kMaxUint32 = 0xFFFFFFFFULL;

// Work assigned to one class of core ("small" or "big"), expressed in elements.
struct CoreWork {
    uint64_t dataNum;      // total elements handled by the core
    uint64_t loopNum;      // number of UB-sized passes
    uint64_t tailDataNum;  // elements handled in the last pass
};

// Splits `blockNum` BLOCK_SIZE-byte blocks into passes of at most `ubPartBlockNum` blocks.
// `elemsPerBlock` is BLOCK_SIZE / element size. Requires blockNum > 0 and ubPartBlockNum > 0.
CoreWork SplitCoreWork(uint64_t blockNum, uint64_t ubPartBlockNum, uint64_t elemsPerBlock)
{
    CoreWork work{};
    work.dataNum = blockNum * elemsPerBlock;
    // Ceil division: a partially filled final pass still counts as one pass.
    work.loopNum = (blockNum + ubPartBlockNum - 1) / ubPartBlockNum;
    // Every pass but the last is full; with blockNum > 0 the last pass holds (0, ubPartBlockNum] blocks.
    work.tailDataNum = (blockNum - ubPartBlockNum * (work.loopNum - 1)) * elemsPerBlock;
    return work;
}
}  // namespace

// Host-side tiling for BitwiseLeftShift.
//
// The input is measured in BLOCK_SIZE-byte blocks and spread over at most GetCoreNum() cores.
// Each "small" core gets inputBlockNum / coreNum blocks. `tailBlockNum` (= the remainder) cores
// get one extra block ("big" cores); which cores those are is decided by the kernel.
// Each core's share is processed in passes of at most ubPartDataNum elements.
//
// NOTE(review): the size of the `other` input is not consulted. The tiling assumes `other` has
// the same element count as `input` (no broadcasting) — TODO confirm against the kernel.
//
// Returns GRAPH_FAILED when:
//  - the context, input shape or descriptor is missing;
//  - the element size is unknown or does not divide BLOCK_SIZE;
//  - the input is empty;
//  - UB cannot hold one block per partition, or no cores are reported;
//  - a per-core count does not fit in 32 bits;
//  - the tiling buffer or workspace array is unavailable.
static ge::graphStatus TilingFunc(gert::TilingContext* context)
{
    if (context == nullptr) {
        return ge::GRAPH_FAILED;
    }
    const auto* inputShape = context->GetInputShape(0);
    const auto* inputDesc = context->GetInputDesc(0);
    if (inputShape == nullptr || inputDesc == nullptr) {
        return ge::GRAPH_FAILED;
    }

    // Element size in bytes. It must divide BLOCK_SIZE so that every block holds whole elements.
    uint32_t dataTypeLength = 0;
    if (!ge::TypeUtils::GetDataTypeLength(inputDesc->GetDataType(), dataTypeLength) || dataTypeLength == 0 ||
        BLOCK_SIZE % dataTypeLength != 0) {
        return ge::GRAPH_FAILED;
    }

    const int64_t shapeSize = inputShape->GetStorageShape().GetShapeSize();
    if (shapeSize <= 0) {
        // An empty input previously led to a division by zero when splitting across cores.
        // Whether the kernel tolerates a zero-work tiling is not visible here, so reject explicitly.
        return ge::GRAPH_FAILED;
    }
    // Keep the byte length in 64 bits so large tensors do not wrap.
    const uint64_t inputLength = static_cast<uint64_t>(shapeSize) * dataTypeLength;
    const uint64_t inputBlockNum = (inputLength + BLOCK_SIZE - 1) / BLOCK_SIZE;

    auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
    uint64_t ubLength = 0;
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubLength);

    // 1-byte types split UB into twice as many partitions (12) as wider types (6).
    const uint64_t ubPartNum = (dataTypeLength == 1) ? 12 : 6;
    const uint64_t ubPartLength = ubLength / ubPartNum / BUFFER_NUM;
    const uint64_t ubPartBlockNum = ubPartLength / BLOCK_SIZE;
    if (ubPartBlockNum == 0) {
        return ge::GRAPH_FAILED;
    }
    const uint64_t elemsPerBlock = BLOCK_SIZE / dataTypeLength;
    const uint64_t ubPartDataNum = ubPartBlockNum * elemsPerBlock;

    const uint64_t availableCoreNum = ascendcPlatform.GetCoreNum();
    if (availableCoreNum == 0) {
        return ge::GRAPH_FAILED;
    }
    // Never use more cores than there are blocks, so every core gets at least one block.
    const uint64_t coreNum = std::min(availableCoreNum, inputBlockNum);
    const uint64_t everyCoreInputBlockNum = inputBlockNum / coreNum;
    const uint64_t tailBlockNum = inputBlockNum % coreNum;

    const CoreWork smallCore = SplitCoreWork(everyCoreInputBlockNum, ubPartBlockNum, elemsPerBlock);
    const CoreWork bigCore = SplitCoreWork(everyCoreInputBlockNum + 1, ubPartBlockNum, elemsPerBlock);
    // bigCore.dataNum bounds every other per-core count; the values below are written as 32-bit fields.
    if (bigCore.dataNum > kMaxUint32 || ubPartDataNum > kMaxUint32) {
        return ge::GRAPH_FAILED;
    }

    BitwiseLeftShiftTilingData tiling;
    tiling.set_smallCoreDataNum(static_cast<uint32_t>(smallCore.dataNum));
    tiling.set_bigCoreDataNum(static_cast<uint32_t>(bigCore.dataNum));
    tiling.set_ubPartDataNum(static_cast<uint32_t>(ubPartDataNum));
    tiling.set_smallCoreTailDataNum(static_cast<uint32_t>(smallCore.tailDataNum));
    tiling.set_bigCoreTailDataNum(static_cast<uint32_t>(bigCore.tailDataNum));
    tiling.set_smallCoreLoopNum(static_cast<uint32_t>(smallCore.loopNum));
    tiling.set_bigCoreLoopNum(static_cast<uint32_t>(bigCore.loopNum));
    tiling.set_tailBlockNum(static_cast<uint32_t>(tailBlockNum));

    context->SetBlockDim(static_cast<uint32_t>(coreNum));
    context->SetTilingKey(1);

    auto* rawTilingData = context->GetRawTilingData();
    if (rawTilingData == nullptr || tiling.GetDataSize() > rawTilingData->GetCapacity()) {
        return ge::GRAPH_FAILED;
    }
    tiling.SaveToBuffer(rawTilingData->GetData(), rawTilingData->GetCapacity());
    rawTilingData->SetDataSize(tiling.GetDataSize());

    // This operator requests no workspace.
    size_t* currentWorkspace = context->GetWorkspaceSizes(1);
    if (currentWorkspace == nullptr) {
        return ge::GRAPH_FAILED;
    }
    currentWorkspace[0] = 0;
    return ge::GRAPH_SUCCESS;
}
}
namespace ge {
// Shape inference: `out` takes the shape of input 0 (`input`).
// NOTE(review): input 1 (`other`) is not consulted, so no broadcast shape is inferred.
static ge::graphStatus InferShape(gert::InferShapeContext* context)
{
    if (context == nullptr) {
        return GRAPH_FAILED;
    }
    const gert::Shape* x1_shape = context->GetInputShape(0);
    gert::Shape* y_shape = context->GetOutputShape(0);
    // Guard against missing shape slots instead of dereferencing null.
    if (x1_shape == nullptr || y_shape == nullptr) {
        return GRAPH_FAILED;
    }
    *y_shape = *x1_shape;
    return GRAPH_SUCCESS;
}

// Data-type inference: `out` takes the data type of input 0.
// Propagates the status reported by SetOutputDataType.
static ge::graphStatus InferDataType(gert::InferDataTypeContext *context)
{
    if (context == nullptr) {
        return ge::GRAPH_FAILED;
    }
    const auto inputDataType = context->GetInputDataType(0);
    return context->SetOutputDataType(0, inputDataType);
}
}
namespace ops {
class BitwiseLeftShift : public OpDef {
public:
    explicit BitwiseLeftShift(const char* name) : OpDef(name)
    {
        this->Input("input")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT8, ge::DT_INT16, ge::DT_INT32, ge::DT_INT64})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        this->Input("other")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT8, ge::DT_INT16, ge::DT_INT32, ge::DT_INT64})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        this->Output("out")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT8, ge::DT_INT16, ge::DT_INT32, ge::DT_INT64})
            .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
        this->SetInferShape(ge::InferShape).SetInferDataType(ge::InferDataType);
        this->AICore()
            .SetTiling(optiling::TilingFunc);
        this->AICore().AddConfig("ascend910b");
    }
};
OP_ADD(BitwiseLeftShift);
}
